# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_ataripy
'''
已适配

训练记录：
在2号机上训练
20250102：测试分数9.4分，继续训练
20250103: 测试分数3.2分，停止训练，代码存在问题，低分会把高分覆盖掉
20250113: 训练分数7.8分，测试分数16.9分，继续训练
20250114: Learning rate: actor=0.0003, q=0.0003，训练分数9.98，测试分数21分，继续训练
20250115: Learning rate: actor=0.0003, q=0.0003，训练分数12.3，测试分数23分，继续训练
20250116：Learning rate: actor=0.0003, q=0.0003,训练分数12.7分，测试分数25.1分，继续训练
20250117: Learning rate: actor=0.0003, q=0.0003,训练分数12.45分，测试分数25.8分，继续训练
20250119:Learning rate: actor=0.0003, q=0.0003，训练分数13.210分，测试分数25.8分，继续训练
20250120：Learning rate: actor=0.0003, q=0.0003，因断电训练分数未知，测试分数25.8分，继续训练
20250121:Learning rate: actor=0.0003, q=0.0003,训练分数12.7分，测试分数25.8分，调整学习率，继续训练
20250122: Learning rate: actor=0.0003, q=0.0003，训练分数14.280分，测试分数26.9，，继续训练
20250123：Learning rate: actor=0.0003, q=0.0003,全部步数训练完毕，训练分数14.110分，测试分数26.9，重新训练，看起来模型有效果就是超参数有问题
'''

import os
import random
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from lib import common
from stable_baselines3.common.buffers import ReplayBuffer
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
import ale_py

gym.register_envs(ale_py)

@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "ALE/Atlantis2-v5"
    """the id of the environment"""
    total_timesteps: int = 5000000
    """total timesteps of the experiments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""  # smaller than in original paper but evaluation is done only for 100k steps anyway
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 1.0
    """target smoothing coefficient (default: 1)"""
    batch_size: int = 64
    """the batch size of sample from the reply memory"""
    learning_starts: int = 2e4
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 3e-4
    """the learning rate of the Q network network optimizer"""
    update_frequency: int = 4
    """the frequency of training updates"""
    target_network_frequency: int = 8000
    """the frequency of updates for the target networks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""
    target_entropy_scale: float = 0.89
    """coefficient for scaling the autotune entropy target"""


from collections import deque

class FrameStackWrapper(gym.Wrapper):
    def __init__(self, env, n_frames=4):
        super().__init__(env)
        self.n_frames = n_frames
        self.frames = deque([], maxlen=n_frames)

        # 假设原始观察空间是一个 Box
        obs_shape = env.observation_space.shape

        # 修改观察空间以适应堆叠的帧
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(n_frames, *obs_shape[:]),
            dtype=env.observation_space.dtype
        )

    def _get_observation(self):
        # 将帧堆叠在一起
        return np.stack(list(self.frames), axis=0)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.n_frames):
            self.frames.append(obs)
        return self._get_observation(), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_observation(), reward, terminated, truncated, info
    

def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array", obs_type="grayscale")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id, obs_type="grayscale")
        env = gym.wrappers.RecordEpisodeStatistics(env)

        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = FrameStackWrapper(env, 4)

        env.action_space.seed(seed)
        return env

    return thunk


def layer_init(layer, bias_const=0.0):
    nn.init.kaiming_normal_(layer.weight)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# ALGO LOGIC: initialize agent here:
# NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC without stopping actor gradients
# See the SAC+AE paper https://arxiv.org/abs/1910.01741 for more info
# TL;DR The actor's gradients mess up the representation when using a joint encoder
class SoftQNetwork(nn.Module):
    def __init__(self, envs):
        super().__init__()
        obs_shape = envs.single_observation_space.shape
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )

        with torch.inference_mode():
            output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

        self.fc1 = layer_init(nn.Linear(output_dim, 512))
        self.fc_q = layer_init(nn.Linear(512, envs.single_action_space.n))

    def forward(self, x):
        x = x.float()
        x = F.relu(self.conv(x / 255.0))
        x = F.relu(self.fc1(x))
        q_vals = self.fc_q(x)
        return q_vals


class Actor(nn.Module):
    def __init__(self, envs):
        super().__init__()
        obs_shape = envs.single_observation_space.shape
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )

        with torch.inference_mode():
            output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

        self.fc1 = layer_init(nn.Linear(output_dim, 512))
        self.fc_logits = layer_init(nn.Linear(512, envs.single_action_space.n))

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = F.relu(self.fc1(x))
        logits = self.fc_logits(x)

        return logits

    def get_action(self, x):
        logits = self(x / 255.0)
        policy_dist = Categorical(logits=logits)
        action = policy_dist.sample()
        # Action probabilities for calculating the adapted soft-Q loss
        action_probs = policy_dist.probs
        log_prob = F.log_softmax(logits, dim=1)
        return action, log_prob, action_probs
    

def select_device(args):
    if args.cuda and torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available() and args.cuda:
        return torch.device("mps")
    return torch.device("cpu")


import ptan


def test_net(net, env, count=10, device="cpu"):
    '''
    count: 执行游戏的次数（每次都是执行到游戏结束）

    return: （平均奖励，平均步数）
    '''
    rewards = 0.0
    steps = 0
    with torch.no_grad():
        for _ in range(count):
            obs, _ = env.reset()
            while True:
                obs_v = ptan.agent.default_states_preprocessor([obs]).to(device)
                probs = net(obs_v.float())
                probs = probs.cpu().numpy()
                action = np.argmax(probs)
                obs, reward, done, trunc, _ = env.step(action)
                rewards += reward
                steps += 1
                if done or trunc:
                    break
    return rewards / count, steps / count


if __name__ == "__main__":
    import stable_baselines3 as sb3

    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1"  "ale-py==0.8.1" 
"""
        )
    args = tyro.cli(Args)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    # if args.track:
    #     import wandb

    #     wandb.init(
    #         project=args.wandb_project_name,
    #         entity=args.wandb_entity,
    #         sync_tensorboard=True,
    #         config=vars(args),
    #         name=run_name,
    #         monitor_gym=True,
    #         save_code=True,
    #     )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = select_device(args)
    save_path = os.path.join("saves", "cleanrl-sac-atlantis2")

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    test_env = make_env(args.env_id, args.seed, 1, args.capture_video, run_name)()
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    # TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4)
    shceduler_q = optim.lr_scheduler.StepLR(q_optimizer, step_size=50000, gamma=0.9)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4)
    shceduler_actor = optim.lr_scheduler.StepLR(actor_optimizer, step_size=50000, gamma=0.9)

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
    else:
        alpha = args.alpha

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=False,
    )
    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)

    start_steps = 0
    best_reward = 0
    if os.path.exists(save_path) and len(os.listdir(save_path)) > 0:
        # 增加加载模型的代码
        checkpoints = sorted(filter(lambda x: "epoch" in x, os.listdir(save_path)),
                             key=lambda x: int(x.split('_')[2].split('.')[0]))
        checkpoint = torch.load(os.path.join(save_path, checkpoints[-1]), map_location=device, weights_only=False)
        actor.load_state_dict(checkpoint["actor"])
        qf1.load_state_dict(checkpoint["qf1"])
        qf2.load_state_dict(checkpoint["qf2"])
        qf1_target.load_state_dict(checkpoint["qf1_target"])
        qf2_target.load_state_dict(checkpoint["qf2_target"])
        q_optimizer.load_state_dict(checkpoint["q_optimizer"])
        actor_optimizer.load_state_dict(checkpoint["actor_optimizer"])
        alpha = checkpoint["alpha"]
        log_alpha = checkpoint["log_alpha"]
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        start_steps = checkpoint["global_step"]
        if "shceduler_q" in checkpoint:
            shceduler_q.load_state_dict(checkpoint["shceduler_q"])
        else:
            shceduler_q = optim.lr_scheduler.StepLR(q_optimizer, step_size=50000, gamma=0.9)
        if "shceduler_actor" in checkpoint:
            shceduler_actor.load_state_dict(checkpoint["shceduler_actor"])
        else:
            shceduler_actor = optim.lr_scheduler.StepLR(actor_optimizer, step_size=50000, gamma=0.9)

        print("加载模型成功")
        # 打印学习率
        print(f"global_step={start_steps}")
        print(f"Learning rate: actor={actor_optimizer.param_groups[0]['lr']}, q={q_optimizer.param_groups[0]['lr']}")
    

    with common.RewardTracker(writer, stop_reward=10000) as tracker:
        for global_step in range(start_steps, args.total_timesteps):
            # ALGO LOGIC: put action logic here
            if global_step < args.learning_starts:
                actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
            else:
                actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
                actions = actions.detach().cpu().numpy()

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, rewards, terminations, truncations, infos = envs.step(actions)
            tracker.rewards(terminations=terminations, truncations=truncations, rewards=rewards, frame=global_step)

            # TRY NOT TO MODIFY: record rewards for plotting purposes
            if "final_info" in infos:
                for info in infos["final_info"]:
                    # Skip the envs that are not done
                    if "episode" not in info:
                        continue
                    print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                    break

            # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
            real_next_obs = next_obs.copy()
            for idx, trunc in enumerate(truncations):
                if trunc:
                    real_next_obs[idx] = infos["final_observation"][idx]
            rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

            # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
            obs = next_obs

            # ALGO LOGIC: training.
            if global_step > args.learning_starts:
                if global_step % args.update_frequency == 0:
                    data = rb.sample(args.batch_size)
                    # CRITIC training
                    with torch.no_grad():
                        _, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations)
                        qf1_next_target = qf1_target(data.next_observations)
                        qf2_next_target = qf2_target(data.next_observations)
                        # we can use the action probabilities instead of MC sampling to estimate the expectation
                        min_qf_next_target = next_state_action_probs * (
                            torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                        )
                        # adapt Q-target for discrete Q-function
                        min_qf_next_target = min_qf_next_target.sum(dim=1)
                        next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target)

                    # use Q-values only for the taken actions
                    qf1_values = qf1(data.observations)
                    qf2_values = qf2(data.observations)
                    qf1_a_values = qf1_values.gather(1, data.actions.long()).view(-1)
                    qf2_a_values = qf2_values.gather(1, data.actions.long()).view(-1)
                    qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
                    qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
                    qf_loss = qf1_loss + qf2_loss

                    q_optimizer.zero_grad()
                    qf_loss.backward()
                    q_optimizer.step()

                    # ACTOR training
                    _, log_pi, action_probs = actor.get_action(data.observations)
                    with torch.no_grad():
                        qf1_values = qf1(data.observations)
                        qf2_values = qf2(data.observations)
                        min_qf_values = torch.min(qf1_values, qf2_values)
                    # no need for reparameterization, the expectation can be calculated for discrete actions
                    actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    if args.autotune:
                        # re-use action probabilities for temperature loss
                        alpha_loss = (action_probs.detach() * (-log_alpha.exp() * (log_pi + target_entropy).detach())).mean()

                        a_optimizer.zero_grad()
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()

                # update the target networks
                if global_step % args.target_network_frequency == 0:
                    for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                        target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                    for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                        target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

                if global_step % 100 == 0:
                    writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                    writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                    writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                    writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                    writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                    writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                    writer.add_scalar("losses/alpha", alpha, global_step)
                    print("SPS:", int(global_step / (time.time() - start_time)))
                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
                    if args.autotune:
                        writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)
                    
                    checkpoint = {
                        "actor": actor.state_dict(),
                        "qf1": qf1.state_dict(),
                        "qf2": qf2.state_dict(),
                        "qf1_target": qf1_target.state_dict(),
                        "qf2_target": qf2_target.state_dict(),
                        "q_optimizer": q_optimizer.state_dict(),
                        "actor_optimizer": actor_optimizer.state_dict(),
                        "alpha": alpha,
                        "log_alpha": log_alpha,
                        "a_optimizer": a_optimizer.state_dict(),
                        "global_step": global_step,
                        "shceduler_actor": shceduler_actor.state_dict(),
                        "shceduler_q": shceduler_q.state_dict()
                    }

                    common.save_checkpoints(global_step, checkpoint, save_path, "cleanrl-sac")

                    # 测试并保存最好测试结果的庶数据
                    ts = time.time()
                    actor.eval()
                    rewards, steps = test_net(actor, test_env, device=device)
                    actor.train()
                    print("Test done in %.2f sec, reward %.3f, steps %d" % (
                        time.time() - ts, rewards, steps))
                    if best_reward is None or best_reward < rewards:
                        if best_reward is not None:
                            print("Best reward updated: %.3f -> %.3f" % (best_reward, rewards))
                        best_reward = rewards
                    common.save_best_model(rewards, checkpoint, save_path, f"sac-best-{global_step}", keep_best=10)




    envs.close()
    writer.close()